use alloc::ffi::CString;

use super::{ExecutionProvider, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};

/// [MIGraphX execution provider](https://onnxruntime.ai/docs/execution-providers/MIGraphX-ExecutionProvider.html) for
/// hardware acceleration with AMD GPUs.
///
/// Built with the `with_*` methods; the derived [`Default`] starts from device `0`, every flag disabled and no
/// calibration table / save path / load path configured.
#[derive(Debug, Default, Clone)]
pub struct MIGraphX {
	/// Device ordinal, forwarded as `device_id` in the provider options.
	device_id: i32,
	/// Forwarded as `migraphx_fp16_enable`.
	enable_fp16: bool,
	/// Forwarded as `migraphx_int8_enable`.
	enable_int8: bool,
	/// Whether the calibration table is in the native format (`true`) or the JSON dump format (`false`).
	use_native_calibration_table: bool,
	/// Calibration table path, kept as a `CString` so a NUL-terminated pointer can be handed to the C API;
	/// `None` is passed as a null pointer.
	int8_calibration_table_name: Option<CString>,
	/// Where to save the compiled model; `Some` also turns on `migraphx_save_compiled_model`.
	save_model_path: Option<CString>,
	/// Where to load a previously compiled model from; `Some` also turns on `migraphx_load_compiled_model`.
	load_model_path: Option<CString>,
	/// Forwarded as `migraphx_exhaustive_tune`.
	exhaustive_tune: bool
}

// NOTE(review): `impl_ep!` is defined in the parent module; it presumably supplies the shared EP glue such as the
// `build()` method used in the doc examples below — confirm against its definition.
super::impl_ep!(MIGraphX);

impl MIGraphX {
	/// Converts a string argument of one of the builder methods into a [`CString`] for the C API.
	///
	/// `what` names the option being configured and is only used in the panic message.
	///
	/// # Panics
	///
	/// Panics if `value` contains an interior NUL byte, because such a string cannot be represented as a
	/// NUL-terminated C string.
	fn c_string_arg(value: &str, what: &str) -> CString {
		CString::new(value).unwrap_or_else(|err| {
			let pos = err.nul_position();
			panic!("MIGraphX {what} must not contain NUL bytes (found one at byte {pos})")
		})
	}

	/// Configures which device the EP should use.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_device_id(0).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_device_id(mut self, device_id: i32) -> Self {
		self.device_id = device_id;
		self
	}

	/// Enable FP16 quantization for the model.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_fp16(true).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_fp16(mut self, enable: bool) -> Self {
		self.enable_fp16 = enable;
		self
	}

	/// Enable 8-bit integer quantization for the model. Requires
	/// [`MIGraphX::with_int8_calibration_table`] to be set.
	///
	/// This builder does not check that a calibration table was configured; it only records the flag.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_int8(true).with_int8_calibration_table("...", false).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_int8(mut self, enable: bool) -> Self {
		self.enable_int8 = enable;
		self
	}

	/// Configures the path to the input calibration data for int8 quantization.
	///
	/// The `native` parameter specifies the format the calibration data is in - `true` for native int8 format, `false`
	/// for the JSON dump format.
	///
	/// # Panics
	///
	/// Panics if `table_name` contains an interior NUL byte.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_int8(true).with_int8_calibration_table("...", false).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_int8_calibration_table(mut self, table_name: impl AsRef<str>, native: bool) -> Self {
		self.use_native_calibration_table = native;
		self.int8_calibration_table_name = Some(Self::c_string_arg(table_name.as_ref(), "int8 calibration table name"));
		self
	}

	/// Save the compiled MIGraphX model to the given path.
	///
	/// Setting a path also enables saving of the compiled model. The compiled model can then be loaded in subsequent
	/// runs with [`MIGraphX::with_load_model`].
	///
	/// # Panics
	///
	/// Panics if `path` contains an interior NUL byte.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_save_model("./compiled_model.mxr").build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_save_model(mut self, path: impl AsRef<str>) -> Self {
		self.save_model_path = Some(Self::c_string_arg(path.as_ref(), "save model path"));
		self
	}

	/// Load the compiled MIGraphX model (previously generated by [`MIGraphX::with_save_model`]) from
	/// the given path.
	///
	/// Setting a path also enables loading of the compiled model.
	///
	/// # Panics
	///
	/// Panics if `path` contains an interior NUL byte.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_load_model("./compiled_model.mxr").build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_load_model(mut self, path: impl AsRef<str>) -> Self {
		self.load_model_path = Some(Self::c_string_arg(path.as_ref(), "load model path"));
		self
	}

	/// Enable exhaustive tuning; trades loading time for inference performance.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::MIGraphX::default().with_exhaustive_tune(true).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_exhaustive_tune(mut self, enable: bool) -> Self {
		self.exhaustive_tune = enable;
		self
	}
}

impl ExecutionProvider for MIGraphX {
	fn name(&self) -> &'static str {
		"MIGraphXExecutionProvider"
	}

	fn supported_by_platform(&self) -> bool {
		// Reported as available only on x86_64, and only for Linux or Windows.
		cfg!(all(target_arch = "x86_64", any(target_os = "linux", target_os = "windows")))
	}

	// Both lints are feature-dependent: without `load-dynamic`/`migraphx` the gated block is compiled out and
	// `session_builder` goes unused; with either feature the gated block always returns, so the trailing `Err`
	// becomes unreachable.
	#[allow(unused, unreachable_code)]
	fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
		#[cfg(any(feature = "load-dynamic", feature = "migraphx"))]
		{
			use core::ptr;

			use crate::{AsPointer, ortsys};

			// Unset strings become null pointers. The non-null pointers point into `CString`s owned by `self`,
			// so they remain valid for as long as `self` is borrowed, i.e. for the whole call below.
			let calibration_table_ptr = self.int8_calibration_table_name.as_deref().map_or(ptr::null(), |s| s.as_ptr());
			let load_path_ptr = self.load_model_path.as_deref().map_or(ptr::null(), |s| s.as_ptr());
			let save_path_ptr = self.save_model_path.as_deref().map_or(ptr::null(), |s| s.as_ptr());

			let options = ort_sys::OrtMIGraphXProviderOptions {
				device_id: self.device_id,
				migraphx_fp16_enable: self.enable_fp16.into(),
				migraphx_int8_enable: self.enable_int8.into(),
				migraphx_use_native_calibration_table: self.use_native_calibration_table.into(),
				migraphx_int8_calibration_table_name: calibration_table_ptr,
				// Supplying a path is what switches loading/saving of the compiled model on.
				migraphx_load_compiled_model: self.load_model_path.is_some().into(),
				migraphx_load_model_path: load_path_ptr,
				migraphx_save_compiled_model: self.save_model_path.is_some().into(),
				migraphx_save_model_path: save_path_ptr,
				// Passed through unconverted, unlike the flags above (presumably a C `bool` field — confirm in ort-sys).
				migraphx_exhaustive_tune: self.exhaustive_tune
			};
			ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.ptr_mut(), &options)?];
			return Ok(());
		}

		Err(RegisterError::MissingFeature)
	}
}
